// In-place backward step for a softmax loss, operating on `prob`.
//
// Work decomposition: each thread handles one (sample, spatial position)
// pair. The sample index is the global x index of the thread
// (blockIdx.x * blockDim.x + threadIdx.x) and the spatial index is
// blockIdx.y. Threads whose pair falls outside [0, num) x [0, spatial_dim)
// return without touching memory.
//
// Element addressing used by this function:
//   prob / weights (sample n, class c, position s):
//       n * (spatial_dim * prob_dim) + c * spatial_dim + s
//   label (sample n, position s):
//       n * spatial_dim + s
//
// Parameters:
//   prob         read and overwritten in place. The entry of the labelled
//                class is decremented by 1; if `weights` is given, every
//                class entry at this (sample, position) is then multiplied
//                by the weight of the labelled class.
//                (presumably holds the softmax output on entry -- confirm
//                against the caller)
//   label        class index per (sample, position), stored as T and
//                truncated to int.
//   weights      optional; may be NULL, in which case no scaling is applied.
//                Addressed with the same layout as `prob`.
//   num          number of samples.
//   spatial_dim  number of spatial positions per sample.
//   prob_dim     number of classes.
//
// Precondition (NOT checked here): every label value y satisfies
// 0 <= y < prob_dim. A label outside that range yields an out-of-bounds
// access to `prob` (and `weights`).
//
// All flat offsets are computed in 64-bit arithmetic so that tensors with
// more than 2^31 - 1 elements are addressed correctly instead of silently
// wrapping a 32-bit int.
template <typename T>
__device__ void softmax_loss_backward(T *prob, T *label, T *weights, int num, int spatial_dim, int prob_dim) {
  // Widen before multiplying: the built-in index variables are unsigned
  // 32-bit, and narrowing their product to int could turn a large index
  // into a negative one that slips past the bounds check.
  const long long sample = static_cast<long long>(blockIdx.x) * blockDim.x + threadIdx.x;
  const long long pos = blockIdx.y;
  if (sample >= num || pos >= spatial_dim)
    return;

  const long long sp = spatial_dim;
  // Offset of (sample, class 0, pos); stepping by `sp` walks the classes.
  const long long sample_base = sample * (sp * prob_dim) + pos;

  const int y = static_cast<int>(label[sample * sp + pos]);
  // Offset of the labelled class, shared by the `prob` and `weights` lookups.
  const long long target = sample_base + y * sp;
  prob[target] -= 1;

  if (weights != NULL) {
    const T weight = weights[target];
    for (int c = 0; c < prob_dim; ++c) {
      prob[sample_base + c * sp] *= weight;
    }
  }
}

// Kernel entry points with C linkage, so their symbol names are not
// C++-mangled. Each one forwards its arguments unchanged to the templated
// device routine softmax_loss_backward for one concrete element type.
//
// Launch layout consumed by the device routine: the x dimension
// (threadIdx.x, blockIdx.x, blockDim.x) selects the sample and blockIdx.y
// selects the spatial position. `weights` may be NULL.
extern "C" {
  // Single-precision instantiation.
  __global__ void softmax_loss_backward_float(
      float *prob, float *label, float *weights,
      int num, int spatial_dim, int prob_dim)
  {
    softmax_loss_backward<float>(prob, label, weights, num, spatial_dim, prob_dim);
  }

  // Double-precision instantiation.
  __global__ void softmax_loss_backward_double(
      double *prob, double *label, double *weights,
      int num, int spatial_dim, int prob_dim)
  {
    softmax_loss_backward<double>(prob, label, weights, num, spatial_dim, prob_dim);
  }
}

// vim: ft=cuda
